# -*- coding: utf-8 -*-
"""
Created on Fri May 29 13:17:48 2020

@author: liqin
"""
#USAGE:%run splicev10.py -bsj FGFR2_BS.bed -sj FGFR2_LS.bed -gtf FGFR2.gtf -g FGFR2 -category all -c gray

from __future__ import print_function
import sys
import os
import math
import matplotlib
from numpy import arange, linspace, sqrt, random
import numpy as np
import argparse
from collections import defaultdict, namedtuple, Counter
import webbrowser
import re
from itertools import cycle


matplotlib.use('Agg')
from matplotlib.path import Path
import matplotlib.patches as patches
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D


def parse():
    """Parse and validate the command-line arguments.

    The process exits (via ``sys.exit``) when the GTF file or any listed
    junction file does not exist, or when neither ``-g`` nor ``-t`` is given.

    Returns:
        argparse.Namespace: the parsed arguments. ``args.color`` is replaced
        by an RGB tuple of floats in the 0-1 range expected by matplotlib.
    """
    parser = argparse.ArgumentParser(description='Plot transcript')
    parser.add_argument("-bsj", nargs='*', help="Path to backsplice junction bed-formatted files (listed in the same order as the input bam files)")
    parser.add_argument("-stranded", help="If strand-specific sequencing, indicate 'forward' if upstream reads are forward strand, otherwise indicate 'reverse' (True-Seq is 'reverse').")
    parser.add_argument("-sj", nargs='*', help="Path to canonical splice junction bed-formatted files (listed in the same order as the input bam files)" )
    parser.add_argument("-category", help="Sample category, all or high_confidence,default is high_confidence")
    parser.add_argument("-TS", help="whether junction is tumor-specific")
    parser.add_argument("-anno", help="whether junction is annotated")
    parser.add_argument("-is", "--intron-scale", type=float, help="The factor by which intron white space should be reduced")
    # Same help text as before, written without the invalid "\#" escape sequence.
    parser.add_argument("-c", "--color", default="#C21807", type=str, help='Exon color. Hex colors (i.e. "\\#4286f4". For hex, an escape "\\" must precede the argument), RGB (i.e. 211,19,23) or names (i.e. "red")')
    parser.add_argument("-t", "--transcript",  type=str, help='Name of transcript to plot')
    parser.add_argument("-g", "--gene", type=str, help='Name of gene to plot (overrides "-t" flag). Will plot the longest transcript derived from that gene')
    parser.add_argument("-f", "--filter", default=0, type=int, help='Filter out sj and circles that have fewer than this number of counts.')
    parser.add_argument("-n", "--normalize", action='store_true', help='Normalize coverage between samples')
    parser.add_argument("-rc", "--reduce_canonical", type=float, help='Factor by which to reduce canonical curves')
    parser.add_argument("-rbs", "--reduce_backsplice", type=float, help='Factor by which to reduce backsplice curves')
    parser.add_argument("-ro", "--repress_open", action='store_true', help='Do not open plot in browser (only save it)')
    parser.add_argument("-en", "--exon_numbering", action='store_true', help='Label exons')
    parser.add_argument("-gtf", required=True, help="Path to gtf file")
    parser.add_argument('-rnabp', nargs='*', help="List of RNABPs to plot.")
    parser.add_argument('-rnabpc', nargs='*', help="Colors to use for RNABPs")
    parser.add_argument("-fa", help="Path to fasta file")
    parser.add_argument('-format', help="Output image format ('SVG', 'PDF', 'PNG', 'JPG', 'TIFF')", default='PNG')
    parser.add_argument("-alu", help="Path to Alu bed file")
    args = parser.parse_args()

    # Check GTF, SJ, and BSJ paths.
    if not os.path.exists(args.gtf):
        sys.exit("GTF: {} was not found".format(args.gtf))

    for path in args.sj or []:
        if not os.path.exists(path):
            sys.exit("Splice junction file {} was not found.\n If you want to obtain splice junction reads from the input bam files, don't specify a splice junction file.".format(path))

    for path in args.bsj or []:
        if not os.path.exists(path):
            sys.exit("Backsplice junction file {} was not found.\n If you want to obtain backsplice junction reads from the input bam files, don't specify a backsplicejunction file".format(path))

    if not (args.gene or args.transcript):
        sys.exit('Either a gene or a transcript must be specified. (ex. "-t ENST00000390665" or "-g EGFR")')

    # Color: "r,g,b" with 0-255 integers is converted directly; anything else
    # (hex, color name, malformed triple) goes through to_rgb(), which reports
    # invalid input and falls back to red.
    rgb = None
    if ',' in args.color:
        try:
            components = [int(c) for c in args.color.split(',')]
        except ValueError:
            components = []
        if len(components) == 3 and all(0 <= c <= 255 for c in components):
            # Matplotlib requires rgb values between 0 and 1 (rather than
            # 0-255). 255.0 forces true division on Python 2 as well.
            rgb = tuple(c / 255.0 for c in components)

    args.color = rgb if rgb is not None else to_rgb(args.color)
    return args
def calc_bez_max(p0, p1, p2, p3=None, t=0.5, quadratic=False):
    """Evaluate a Bezier curve at parameter ``t`` and return the ``(x, y)`` point.

    The control points must expose ``.x`` and ``.y``. With ``quadratic=True``
    only ``p0``, ``p1`` and ``p2`` are used; otherwise the cubic form is
    evaluated and ``p3`` is required.
    """
    u = 1 - t

    if not quadratic:
        # Cubic Bernstein form.
        cx = u * u * u * p0.x + 3 * u * u * t * p1.x + 3 * u * t * t * p2.x + t * t * t * p3.x
        cy = u * u * u * p0.y + 3 * u * u * t * p1.y + 3 * u * t * t * p2.y + t * t * t * p3.y
        return cx, cy

    # Quadratic Bernstein form.
    qx = u * u * p0.x + 2 * u * t * p1.x + t * t * p2.x
    qy = u * u * p0.y + 2 * u * t * p1.y + t * t * p2.y
    return qx, qy


# Used to adjust text labels when a collision occurs.
class Box:
    """Axis-aligned rectangle described by its x extent and its y extent."""

    def __init__(self, x0, x1, y0, y1):
        # Horizontal extent.
        self.x0, self.x1 = x0, x1
        # Vertical extent.
        self.y0, self.y1 = y0, y1


def intersect(boxa, boxb, subtract):
    """Move ``boxb`` (in place) until it no longer collides with ``boxa``.

    ``boxa`` is padded by a quarter of its own width/height on every side.
    While an x edge and a y edge of ``boxb`` both fall inside that padded
    region, ``boxb`` is shifted vertically by 0.02 (down when ``subtract`` is
    true, up otherwise) and jittered horizontally by a random +/- 1/40 of
    ``boxa``'s width. Afterwards the upper y-limit of the current axes is
    raised if needed so the top of ``boxb`` stays visible.

    Implemented as a loop rather than recursion so that a long run of nudges
    cannot exceed the interpreter's recursion limit.

    Returns:
        Box: the same ``boxb`` object, after any shifting.
    """
    xextra = (boxa.x1 - boxa.x0)/4
    yextra = (boxa.y1 - boxa.y0)/4

    def collides():
        # Only the edges of boxb are tested against the padded boxa.
        x_hit = (boxa.x0 - xextra <= boxb.x0 <= boxa.x1 + xextra
                 or boxa.x0 - xextra <= boxb.x1 <= boxa.x1 + xextra)
        if not x_hit:
            return False
        return (boxa.y0 - yextra <= boxb.y0 <= boxa.y1 + yextra
                or boxa.y0 - yextra <= boxb.y1 <= boxa.y1 + yextra)

    while collides():
        if not subtract:
            boxb.y0 += .02
            boxb.y1 += .02
        else:
            boxb.y0 -= .02
            boxb.y1 -= .02

        # Random horizontal jitter so stacked labels do not line up exactly.
        if random.randint(2) == 1:
            boxb.x0 += xextra/10
            boxb.x1 += xextra/10
        else:
            boxb.x0 -= xextra/10
            boxb.x1 -= xextra/10

    ax = plt.gca()
    ymax = ax.get_ylim()[1]
    plt.ylim([plt.ylim()[0], max([ymax, boxb.y1])])

    return boxb

def draw_exons(ax, exon_coords, y, height, colors, gene_length):
    """Draw the exon boxes and a strand arrow on ``ax``.

    Args:
        ax: matplotlib axes to draw on.
        exon_coords: sequence of ``(start, stop, strand)`` tuples, sorted by
            start. The strand is read from the first entry only.
        y: y coordinate of the bottom edge of a stand-alone exon box.
        height: height of a stand-alone exon box.
        colors: one face color per drawn patch, consumed in order. The list
            is not modified.
        gene_length: plotted gene length; the strand arrow spans 2% of it.

    Two consecutive exons where the second starts exactly where the first
    ends are drawn as a single patch. As in the original implementation that
    merged patch extends from ``y - height`` to ``y + height``, whereas a
    stand-alone exon extends from ``y`` to ``y + height``.
    NOTE(review): the taller merged patch looks like a leftover of a CDS/UTR
    distinction - confirm whether it is intended.
    """
    boxes = [(e[0], e[1], e[2] if len(e) > 2 else None) for e in exon_coords]
    if not boxes:
        return

    h = height
    arr_len = gene_length * 0.02

    i = 0
    patch_index = 0
    while i < len(boxes):
        left, right = boxes[i][0], boxes[i][1]

        if i + 1 < len(boxes) and boxes[i + 1][0] == right:
            # Abutting exons: one outline around both, then skip the second.
            next_left, next_right = boxes[i + 1][0], boxes[i + 1][1]
            vertices = [
                (left, y + h),
                (right, y + h),
                (next_left, y + h),
                (next_right, y + h),
                (next_right, y - h),
                (next_left, y - h),
                (left, y - h),
                (left, y + h),
            ]
            i += 2
        else:
            vertices = [
                (left, y + h),
                (right, y + h),
                (right, y),
                (left, y),
                (left, y + h),
            ]
            i += 1

        codes = [Path.MOVETO] + [Path.LINETO] * (len(vertices) - 2) + [Path.CLOSEPOLY]
        patch = patches.PathPatch(Path(vertices, codes), facecolor=colors[patch_index], lw=0.3, ec='k')
        patch_index += 1
        ax.add_patch(patch)

    # Strand arrow, drawn once. The annotation text is passed positionally
    # because its keyword name differs between matplotlib versions
    # ("s" in old releases, "text" in newer ones).
    # "-|>" is an arrow with a filled head; "rad" in the connection style
    # controls the rounding of the bend.
    arrow = dict(arrowstyle="-|>", connectionstyle="angle,rad=1", fc='r')
    strand = boxes[0][2]
    if strand == "+":
        ax.annotate('', xy=(boxes[0][0] + arr_len, y + 1.5 * h),
                    xytext=(boxes[0][0], y + h),
                    arrowprops=arrow)
    elif strand == "-":
        # NOTE(review): the arrow is anchored at the *start* coordinate of the
        # last exon, as before - confirm whether its end was intended.
        ax.annotate('', xy=(boxes[-1][0] - arr_len, y + 1.5 * h),
                    xytext=(boxes[-1][0], y + h),
                    arrowprops=arrow)


#notice bezier_offset is the depth of the circle junctions
def draw_backsplice(ax, start, stop, y, y2, adjust, bezier_offset, gene_size, lw=1.0, plot=True):
    '''Build (and optionally draw) the cubic bezier curve for a backsplice (circle) junction.

    The curve runs from ``(start, y)`` to ``(stop, y2)`` and bulges downwards:
    the two inner control points are pushed outwards by ``gene_size / 20`` in
    x and lowered by ``bezier_offset`` times the current y-range of ``ax``.

    Args:
        ax: matplotlib axes; its current y-limits set the curve depth.
        start, stop: x coordinates of the junction ends.
        y, y2: y coordinates of the two curve ends.
        adjust: unused; kept for backward compatibility.
        bezier_offset: depth of the curve as a fraction of the y-range.
        gene_size: plotted gene length.
        lw: line width of the curve. Optional so that callers which only need
            the control points (``plot=False``) do not have to supply it.
        plot: when False nothing is drawn and the control points are returned.

    Returns:
        ``(p0, p1, p2, p3)`` control points when ``plot`` is False, else None.
    '''
    ylim = ax.get_ylim()
    space = ylim[1] - ylim[0]
    size_adjust = gene_size / 20.0
    depth = bezier_offset * space

    Point = namedtuple('Point', ['x', 'y'])
    p0 = Point(start, y)
    p1 = Point(start - size_adjust, y - depth)
    p2 = Point(stop + size_adjust, y2 - depth)
    p3 = Point(stop, y2)

    if not plot:
        return p0, p1, p2, p3

    verts = [p0, p1, p2, p3]
    codes = [Path.MOVETO, Path.CURVE4, Path.CURVE4, Path.CURVE4]

    path = Path(verts, codes)
    patch = patches.PathPatch(path, facecolor='none', lw=lw, alpha=.7, ec='0')
    ax.add_patch(patch)


def draw_canonical_splice(ax, start, stop, y, y2, adjust, bezier_offset, lw=1.0, anno_type="yes", as_type="nTS", plot=True):
    '''Build (and optionally draw) the cubic bezier curve for a canonical (linear) splice junction.

    The curve runs from ``(start, y)`` to ``(stop, y2)`` and bulges upwards.
    Its height is ``bezier_offset * sqrt(junction length / x-range of ax)``.

    Args:
        ax: matplotlib axes; its current x-limits are used to scale the height.
        start, stop: x coordinates of the junction ends.
        y, y2: y coordinates of the two curve ends.
        adjust: unused; kept for backward compatibility.
        bezier_offset: height scale of the curve.
        lw: line width of the curve.
        anno_type: "no" marks an unannotated junction (drawn red).
        as_type: anything other than "nTS" is drawn blue and takes precedence
            over ``anno_type``.
        plot: when False nothing is drawn and the control points are returned.

    ``lw``, ``anno_type`` and ``as_type`` are optional so that callers which
    only need the control points (``plot=False``) can omit them.

    Returns:
        ``(p0, p1, p2, p3)`` control points when ``plot`` is False, else None.
    '''
    xlim = ax.get_xlim()
    xspace = xlim[1] - xlim[0]
    length = sqrt(abs(stop - start) / xspace)
    rise = length * bezier_offset
    span = stop - start

    Point = namedtuple('Point', ['x', 'y'])
    p0 = Point(start, y)
    p1 = Point(start + 0.2 * span, y + rise)
    p2 = Point(stop - 0.2 * span, y2 + rise)
    p3 = Point(stop, y2)
    if not plot:
        return p0, p1, p2, p3

    verts = [p0, p1, p2, p3]
    codes = [Path.MOVETO, Path.CURVE4, Path.CURVE4, Path.CURVE4]
    path = Path(verts, codes)

    # Same precedence as before: the as_type color wins, then unannotated
    # junctions; every other value (including an unknown anno_type, which
    # previously left the patch undefined) is drawn in black.
    if as_type != "nTS":
        edge_color = 'blue'
    elif anno_type == "no":
        edge_color = 'red'
    else:
        edge_color = '0'

    patch = patches.PathPatch(path, facecolor='none', lw=lw, ec=edge_color)
    ax.add_patch(patch)


# Plots the backsplice (circle junction) curves.
def plot_circles(ax, coordinates, y, cds_range, gene_size, numbering=False, fig=None):
    '''Plot one backsplice curve per junction using draw_backsplice().

    Args:
        ax: matplotlib axes to draw on.
        coordinates: iterable of ``(start, stop, counts[, type])`` tuples.
            Only the first three fields are used; entries whose count is 0
            are skipped. The line width is ``0.3 * counts``.
        y: reference y coordinate; the curves end 0.1 below it.
        cds_range: unused; kept for backward compatibility.
        gene_size: plotted gene length, forwarded to draw_backsplice().
        numbering: when True, label each curve with its count and push
            colliding labels apart.
        fig: figure used to obtain a renderer for measuring the labels;
            defaults to the figure that owns ``ax``.
    '''
    y1 = y2 = y - 0.2*0.5
    texts = []

    for junction in coordinates:
        # Tolerate tuples with or without the trailing junction-type field.
        start, stop, counts = junction[:3]
        if counts == 0:
            continue

        if start > stop:
            start, stop = stop, start

        lw = 0.3*counts
        draw_backsplice(ax=ax, start=start, stop=stop, y=y1, y2=y2, adjust=1, bezier_offset=0.2, gene_size=gene_size, lw=lw)

        if numbering:
            p0, p1, p2, p3 = draw_backsplice(ax=ax, start=start, stop=stop, y=y1, y2=y2, adjust=1, bezier_offset=0.2, gene_size=gene_size, lw=lw, plot=False)
            x_mid, y_mid = calc_bez_max(p0, p1, p2, p3)
            text = ax.annotate(str(counts), (x_mid, y_mid), ha='center', va='top', fontsize=8, xytext=(x_mid, y_mid-.3), arrowprops={'arrowstyle':'-','alpha':.05, 'lw':1})
            texts.append(text)

    # Nothing to lay out unless labels were actually created.
    if not (numbering and texts):
        return

    if fig is None:
        fig = ax.get_figure()
    # The text extents are only known once the figure has been drawn.
    plt.draw()
    r = fig.canvas.get_renderer()

    boxes = []
    for text in texts:
        extent = text.get_window_extent(r).transformed(ax.transData.inverted())
        boxes.append(Box(extent.xmin, extent.xmax, extent.ymin, extent.ymax))

    # Push every label away from every other label (downwards).
    for indexa, boxa in enumerate(boxes):
        for indexb, boxb in enumerate(boxes):
            if indexa != indexb:
                boxes[indexb] = intersect(boxa, boxb, subtract=True)

    for text, box in zip(texts, boxes):
        x_mid = (box.x0 + box.x1)/2
        text.set_position((x_mid, box.y1))

    # Extend the lower y-limit so the lowest label stays visible.
    ymin = ax.get_ylim()[0]
    plt.ylim([min([ymin, min(i.y0 for i in boxes)]), ax.get_ylim()[1]])

# Plots the canonical (linear) splice junction curves.
def plot_SJ_curves(ax, coordinates, y, cds_range, numbering=False, fig=None):
    '''Plot one canonical splice curve per junction using draw_canonical_splice().

    Args:
        ax: matplotlib axes to draw on.
        coordinates: iterable of ``(start, stop, counts[, anno_type[, as_type]])``
            tuples. Entries whose count is 0 are skipped. When the last two
            fields are missing they default to "yes" and "nTS" (plain black
            curve). The line width is ``0.3 * counts``.
        y: y coordinate at which the curves start and end.
        cds_range: unused; kept for backward compatibility.
        numbering: when True, label each curve with its count and push
            colliding labels apart.
        fig: figure used to obtain a renderer for measuring the labels;
            defaults to the figure that owns ``ax``.
    '''
    y1 = y2 = float(y)
    texts = []

    for junction in coordinates:
        start, stop, counts = junction[:3]
        anno_type = junction[3] if len(junction) > 3 else "yes"
        as_type = junction[4] if len(junction) > 4 else "nTS"
        if counts == 0:
            continue

        if start > stop:
            start, stop = stop, start

        lw = 0.3*counts
        draw_canonical_splice(ax=ax, start=start, stop=stop, y=y1, y2=y2, adjust=1, bezier_offset=2.5, lw=lw, anno_type=anno_type, as_type=as_type)

        if numbering:
            p0, p1, p2, p3 = draw_canonical_splice(ax=ax, start=start, stop=stop, y=y1, y2=y2, adjust=1, bezier_offset=2.5, lw=lw, anno_type=anno_type, as_type=as_type, plot=False)
            x_mid, y_mid = calc_bez_max(p0, p1, p2, p3)
            text = ax.annotate(str(counts), (x_mid, y_mid), ha='center', va='bottom', fontsize=8, xytext=(x_mid, y_mid+.1), arrowprops={'arrowstyle':'-','alpha':.05, 'lw':1})
            texts.append(text)

    # Nothing to lay out unless labels were actually created.
    if not (numbering and texts):
        return

    if fig is None:
        fig = ax.get_figure()
    # The text extents are only known once the figure has been drawn.
    plt.draw()
    r = fig.canvas.get_renderer()

    boxes = []
    for text in texts:
        extent = text.get_window_extent(r).transformed(ax.transData.inverted())
        boxes.append(Box(extent.xmin, extent.xmax, extent.ymin, extent.ymax))

    # Push every label away from every other label (upwards).
    for indexa, boxa in enumerate(boxes):
        for indexb, boxb in enumerate(boxes):
            if indexa != indexb:
                boxes[indexb] = intersect(boxa, boxb, subtract=False)

    for text, box in zip(texts, boxes):
        x_mid = (box.x0 + box.x1)/2
        text.set_position((x_mid, box.y0))

    # Extend the upper y-limit so the highest label stays visible.
    ymax = ax.get_ylim()[1]
    plt.ylim([ax.get_ylim()[0], max([ymax, max(i.y1 for i in boxes)])])


def scale_introns(coords, scaling_factor):
    '''Shrink the gaps (introns) between consecutive exons.

    Args:
        coords: exon tuples ``(start, stop, ...)`` sorted by start. Fields
            after the first two (e.g. the strand) are carried over unchanged.
        scaling_factor: every intron length is divided by this value.

    Returns:
        A new list of exon tuples with the first exon unchanged, exon lengths
        preserved and introns shortened. If ``scaling_factor`` is <= 0 a
        message is printed and ``coords`` itself is returned.
    '''

    if scaling_factor <= 0:
        print("Intron scaling factor must be > 0. Plotting without scaling.")
        return coords

    if not coords:
        return []

    factor = float(scaling_factor)
    newcoords = [coords[0]]

    for previous, current in zip(coords, coords[1:]):
        length = current[1] - current[0]
        intron = (current[0] - previous[1]) / factor
        left = newcoords[-1][1] + intron
        # Keep trailing fields (strand) so later code can still read them.
        newcoords.append((left, left + length) + tuple(current[2:]))

    return newcoords
def to_rgb(color):
    '''Convert a "#RRGGBB" hex string or a known color name to an RGB tuple.

    Color names are matched case-insensitively. Any invalid input (wrong
    type, unknown name, empty string, malformed hex) prints a message and
    yields red, ``(1, 0, 0)``.

    Returns:
        tuple: three floats in the 0-1 range.
    '''

    colordict = {
        'brown': '#663300',
        'red': '#FF0000',
        'blue': '#0000FF',
        'green': '#006600',
        'yellow': '#FFFF00',
        'purple': '#990099',
        'black': '#000000',
        'white': '#FFFFFF',
        'orange': '#FF8000',
        'gray':'#A9A9A9'
    }
    fallback = (1, 0, 0)

    if not isinstance(color, str):
        print("Invalid color input: %s\n Color is set to red" % color)
        return fallback

    if not color.startswith('#') or len(color) != 7:
        # Not a 7-character hex string: treat it as a color name. The lookup
        # key is lower-cased before the membership test so that "Blue"
        # resolves just like "blue".
        name = color.lower()
        if name not in colordict:
            print("Invalid color input: %s\n Color is set to red" % color)
            return fallback
        color = colordict[name]

    try:
        rgb = tuple([int(color[i:i+2], 16)/255.0 for i in range(1, len(color), 2)])
    except ValueError:
        print("Invalid hex input: %s. Values must range from 0-9 and A-F.\n Color is set to red" % color)
        return fallback

    return rgb


def junction_file_parse1(bed_path, chromosome, upstream, downstream, strand=None, min_freq=0, min_count=0,
                         ts_only=None, unannotated_only=None):
    """Read canonical splice junctions from a tab-separated junction file.

    Only lines starting with "ch" are read. Columns used (0-based):
    1 start, 2 stop, 3 strand, 5 counts, 6 frequency, 7 tumor-specificity
    label ("nTS" = not tumor specific), 8 annotation flag ("yes" = annotated).

    Args:
        bed_path: path of the junction file.
        chromosome: unused; lines are not filtered by chromosome.
        upstream, downstream: both junction ends must lie inside this
            inclusive range.
        strand: when given, only junctions on this strand are kept.
        min_freq, min_count: a junction is dropped only when its frequency is
            below ``min_freq`` AND its log2 count is below ``min_count``.
            NOTE(review): the comparison uses the log2-transformed count, and
            junction_file_parse2 combines its thresholds with "or" - confirm
            which rule is intended.
        ts_only: keep only tumor-specific junctions. ``None`` (default) takes
            the setting from the command line (``-TS TS``).
        unannotated_only: keep only unannotated junctions. ``None`` (default)
            takes the setting from the command line (``-anno un``).

    Returns:
        list of ``(start, stop, log2(counts), anno_type, ts_type)`` tuples
        with ``start <= stop``. The tuple shape is the same whether or not
        ``strand`` is given, so the result can always be passed to
        plot_SJ_curves().
    """
    junctions = []
    with open(bed_path, "r") as input_file:
        for line in input_file:
            if not line.startswith("ch"):
                continue

            fields = line.strip().split("\t")
            _, start, stop, bed_strand = fields[:4]
            counts = float(fields[5])
            frequency = float(fields[6])
            ts_type = fields[7]
            anno_type = fields[8]

            if counts == 0:
                continue
            counts = math.log(counts, 2)

            # Resolve command-line driven filters once, on first use, instead
            # of re-parsing the command line for every line of the file.
            if ts_only is None or unannotated_only is None:
                args = parse()
                if ts_only is None:
                    ts_only = args.TS == "TS"
                if unannotated_only is None:
                    unannotated_only = args.anno == "un"

            if ts_only and ts_type == "nTS":
                continue
            if unannotated_only and anno_type == "yes":
                continue
            if frequency < min_freq and counts < min_count:
                continue

            # Coordinates are used as given (no 0-based to 1-based shift).
            start = int(start)
            stop = int(stop)
            if start > stop:
                start, stop = stop, start
            if not (upstream <= start <= downstream and upstream <= stop <= downstream):
                continue
            if strand and strand != bed_strand:
                continue

            junctions.append((start, stop, counts, anno_type, ts_type))

    return junctions
def junction_file_parse2(bed_path, chromosome, upstream, downstream, strand=None, min_freq=0, min_count_circ=0,
                         unannotated_only=None):
    """Read backsplice (circle) junctions from a tab-separated junction file.

    Only lines starting with "ch" are read. Columns used (0-based):
    1 start, 2 stop, 3 strand, 5 counts, 6 frequency, 7 junction type.

    Args:
        bed_path: path of the junction file.
        chromosome: unused; lines are not filtered by chromosome.
        upstream, downstream: both junction ends must lie inside this
            inclusive range.
        strand: when given, only junctions on this strand are kept.
        min_freq, min_count_circ: a junction is dropped when its frequency is
            <= ``min_freq`` OR its log2 count is <= ``min_count_circ``. With
            the default thresholds of 0 this removes single-read junctions,
            whose log2 count is 0.
            NOTE(review): confirm that dropping single-read junctions at the
            default thresholds is intended.
        unannotated_only: when true, every junction is skipped. ``None``
            (default) takes the setting from the command line (``-anno un``).

    Returns:
        list of ``(start, stop, log2(counts), as_type)`` tuples with
        ``start <= stop``. The tuple shape is the same whether or not
        ``strand`` is given, so the result can always be passed to
        plot_circles().
    """
    junctions = []
    with open(bed_path, "r") as input_file:
        for line in input_file:
            if not line.startswith("ch"):
                continue

            fields = line.strip().split("\t")
            _, start, stop, bed_strand, gene, counts, frequency, as_type = fields[:8]
            counts = float(counts)
            if counts == 0:
                continue
            counts = math.log(counts, 2)
            frequency = float(frequency)

            # Resolve the command-line driven filter once, on first use,
            # instead of re-parsing the command line for every line.
            if unannotated_only is None:
                unannotated_only = parse().anno == "un"
            if unannotated_only:
                continue

            if frequency <= min_freq or counts <= min_count_circ:
                continue

            # Coordinates are used as given (no 0-based to 1-based shift).
            start = int(start)
            stop = int(stop)
            if start > stop:
                start, stop = stop, start
            if not (upstream <= start <= downstream and upstream <= stop <= downstream):
                continue
            if strand and strand != bed_strand:
                continue

            junctions.append((start, stop, counts, as_type))

    return junctions
def plot_bp(ax, positions, y, color, transcript_length):
    """Mark each position with a translucent ellipse plus a thin black outline.

    The ellipse width is 1/200 of ``transcript_length``; its height is that
    width multiplied by ``9 / transcript_length``. When ``color`` is falsy the
    filled ellipse uses matplotlib's default face color with alpha 0.7,
    otherwise ``color`` with alpha 0.4.
    """
    ax_len = transcript_length
    aspect = 9 / ax_len

    marker_width = ax_len / 200
    marker_height = aspect * marker_width

    for position in positions:
        center = (position, y)

        if color:
            filled = patches.Ellipse(center, marker_width, marker_height, facecolor=color,  alpha=.4)
        else:
            filled = patches.Ellipse(center, marker_width, marker_height,  alpha=.7)
        ax.add_patch(filled)

        # Unfilled copy on top gives the marker a crisp border.
        outline = patches.Ellipse(center, marker_width, marker_height, fill=False, edgecolor='k',linewidth=.3)
        ax.add_patch(outline)


def exons(path, gene, transcript=False):
    '''Collect exon and CDS coordinates for a gene or transcript from a GTF file.

    Lines are selected by a case-insensitive match of ``gene_name "<gene>"``
    (or ``transcript_id "<gene>"`` when ``transcript`` is true) anywhere in
    the line. The name is matched literally: regex metacharacters in it have
    no special meaning.

    Exons of ALL matching lines are pooled; when searching by gene name this
    is the union of the exons of every transcript of that gene, not the exons
    of a single (e.g. longest) transcript.

    Args:
        path: path of the GTF file.
        gene: gene name, or transcript id when ``transcript`` is true.
        transcript: search by transcript id instead of gene name.

    Returns:
        tuple ``(chromosome, coordinates, strand, cds)`` where ``coordinates``
        is a sorted list of unique ``(start, stop, strand)`` exon tuples and
        ``cds`` is a list of ``(start, stop)`` tuples sorted by start.

    Exits (``sys.exit``) when no exon is found, or when the matching exons
    lie on more than one chromosome or on both strands.
    '''
    attribute = 'transcript_id' if transcript else 'gene_name'
    prog = re.compile('%s "%s"' % (attribute, re.escape(gene)), flags=re.IGNORECASE)

    exon_info = []
    cds = []
    with open(path) as gtf:
        for line in gtf:
            # Skip header lines and lines that do not mention the target.
            if line.startswith('#') or not prog.search(line):
                continue

            fields = line.rstrip('\n').split('\t')
            if len(fields) < 9:
                # Not a complete GTF record.
                continue

            chromosome, _, feature, start, stop, _, strand, _ = fields[:8]
            feature = feature.lower()
            if feature == 'exon':
                exon_info.append((chromosome, int(start), int(stop), strand))
            elif feature == 'cds':
                cds.append((int(start), int(stop)))

    if not exon_info:
        sys.exit("Gene {} not found in gtf file.".format(gene))

    # Unique exons; sorting the full tuple orders by start and makes the
    # order of ties deterministic.
    coordinates = sorted(set((start, stop, strand) for _, start, stop, strand in exon_info))

    chromosomes = set(info[0] for info in exon_info)
    if len(chromosomes) != 1:
        sys.exit("Gene {} found on more than one chromosome. Please fix GTF file.".format(gene))
    chromosome = chromosomes.pop()

    strands = set(info[3] for info in exon_info)
    if len(strands) != 1:
        sys.exit("Gene {} found on both DNA strands. Please fix GTF file.".format(gene))
    strand = strands.pop()

    cds.sort(key=lambda x: x[0])

    return chromosome, coordinates, strand, cds

def plot_coverage_curve(ax, x_vals, y_vals, y_bottom, y_top, direction='above',height_factor=.6):
    """Draw a filled coverage profile on ``ax``, centred on the exon track.

    ``y_vals`` are scaled by their maximum so the peak reaches
    ``height_factor`` of the distance from the middle of
    ``[y_bottom, y_top]`` to ``y_top`` (``direction='above'``) or to
    ``y_bottom`` (any other value). Nothing is drawn when ``y_vals`` is empty
    or has no positive value.

    Both the outline and the fill are drawn on the ``ax`` that was passed in,
    not on whichever axes pyplot currently considers active.
    """
    y_middle = (y_top + y_bottom) / 2.0

    if direction == 'above':
        y_range = y_top - y_middle
    else:
        y_range = y_bottom - y_middle
    y_range *= height_factor

    heights = np.asarray(y_vals, dtype=float)
    if heights.size == 0:
        return

    peak = heights.max()
    if peak > 0:
        scaled = y_middle + ((heights / peak) * y_range)
        ax.plot(x_vals, scaled, color='k', linewidth=.2)
        ax.fill_between(x_vals, y_middle, scaled, color='0.75', interpolate=False, linewidth=.2, edgecolor='k' )
def get_coverage(chromosome, start, stop, strand=None, rev=False, average=True):
    '''Return placeholder coverage for a region.

    No alignment data is read: every position gets a constant coverage of
    1000. The region is first widened on both sides by 1/20 of its length
    (integer division). ``chromosome``, ``strand`` and ``rev`` are unused.

    Returns the mean coverage when ``average`` is true, otherwise a
    ``(positions, coverage)`` pair of equally long lists.
    '''
    padding = (stop - start) // 20
    region_start = start - padding
    region_stop = stop + padding

    coverage = [1000] * (region_stop - region_start)

    if not average:
        return list(range(region_start, region_stop)), coverage
    return np.mean(coverage)

def main():
    """Build the splice plot for one gene/transcript and save it as an image.

    Steps: parse the command line, read the exons from the GTF file, read the
    canonical and backsplice junction files, optionally scale introns, then
    draw exons, backsplice curves (below) and canonical curves (above) and
    save the figure as "<gene or transcript>.<format>" in the working
    directory.

    NOTE(review): scale_coords, transform, alu_file_parse, read_fasta and
    bp_positions are called below but are not defined in the visible part of
    this file; the -is, -alu and -rnabp options depend on them.
    """
    args = parse()

    if args.gene:
        chromosome, exon_coordinates, strand, cds_coordinates  = exons(path=args.gtf, gene=args.gene)
    else:
        chromosome, exon_coordinates, strand, cds_coordinates = exons(path=args.gtf, gene=args.transcript, transcript=True)

    # parse() guarantees that at least one of the two is set; "-g" wins.
    # Used for the subplot title and the output file name, so that a run with
    # only "-t" no longer fails or writes "None.<ext>".
    label = args.gene or args.transcript

    transcript_start = min(i[0] for i in exon_coordinates)
    transcript_stop = max(i[1] for i in exon_coordinates)
    transcript_len = transcript_stop - transcript_start
    if args.intron_scale:
        factor = args.intron_scale
        scaled_coords = scale_introns(exon_coordinates, factor)
        cds_coordinates = scale_coords(exon_coordinates, scaled_coords, cds_coordinates)

    if cds_coordinates:
        cds_range = [min(i[0] for i in cds_coordinates), max(i[1] for i in cds_coordinates)]

    else:
        cds_range = [0,0]
    cds_range.sort()

    if args.stranded:
        junction_strand = strand
        rev = args.stranded == 'reverse'
        new_strand = strand

    else:
        rev = False
        new_strand = None
        junction_strand = None

    samples = []
    # Thresholds used for every category other than "all".
    min_freq =5
    min_count=5
    min_freq_circ=10
    min_count_circ=1
    for s in range(1):
        name = label

        # Defaults so that omitting -sj or -bsj plots no curves instead of
        # failing on an undefined name. The zero-count placeholder entries
        # are skipped by plot_circles().
        canonical = []
        circle=[(transcript_start,transcript_stop,0,'circ'),(transcript_start,transcript_stop,0,'circ')]

        if args.category=="all":
            if args.sj:
                canonical = junction_file_parse1(args.sj.pop(0), chromosome, transcript_start, transcript_stop, junction_strand,min_freq=0,min_count=0)
            if args.bsj:
                circle = junction_file_parse2(args.bsj.pop(0), chromosome, transcript_start, transcript_stop, junction_strand,min_freq=0,min_count_circ=0)
        else:
            if args.sj:
                canonical = junction_file_parse1(args.sj.pop(0), chromosome, transcript_start, transcript_stop, junction_strand, min_freq=min_freq,min_count=min_count)
            if args.bsj:
                circle = junction_file_parse2(args.bsj.pop(0), chromosome, transcript_start, transcript_stop, junction_strand,min_freq=min_freq_circ,min_count_circ=min_count_circ)
        coverage=[]

        x_fill, y_fill = get_coverage(chromosome, transcript_start, transcript_stop, strand=new_strand, rev=rev, average=False)
        # Index access keeps the gene-level "strand" variable intact.
        for exon in exon_coordinates:
            coverage.append(get_coverage(chromosome, exon[0], exon[1], strand=new_strand, rev=rev, average=True))

        if args.intron_scale:
            x_fill = [transform(exon_coordinates, scaled_coords, i) for i in x_fill]
            canonical = scale_coords(exon_coordinates, scaled_coords, canonical)
            circle = scale_coords(exon_coordinates, scaled_coords, circle)


        if args.reduce_backsplice:
            if args.reduce_backsplice <= 0:
                print("-rbs/ --reduce_backsplice must be > 0. Not reducing backsplice junctions.")
            else:
                # Keep any trailing fields (junction type) of each entry.
                circle = [(c[0], c[1], c[2] // args.reduce_backsplice) + tuple(c[3:]) for c in circle]

        samples.append([name, canonical, circle, coverage, x_fill, y_fill])

    if args.intron_scale:
        exon_coordinates = scaled_coords

    # Per-exon colors: the exon color with an alpha proportional to coverage.
    if args.normalize:
        highest = 0

        for sample in samples:
            max_coverage = max(sample[3])
            if max_coverage > highest:
                highest = max_coverage

    for sample in samples:
        coverage = sample[3]
        if args.normalize:
            max_coverage = highest * 1.1
        else:
            max_coverage = max(coverage) * 1.1
        if max_coverage != 0:
            color = [args.color + (i / max_coverage, ) for i in coverage]
        else:
            color = [args.color + (0, ) for i in coverage]

        sample += [color]

    # Per-base coverage profile, scaled to 0-1.
    if args.normalize:
        highest = 0
        for sample in samples:
            max_coverage = max(sample[5])
            if max_coverage > highest:
                highest = max_coverage

    for sample in samples:
        coverage = sample[5]
        if args.normalize:
            max_coverage = highest
        else:
            max_coverage =1000
        if max_coverage != 0:
            coverage = [i / max_coverage for i in coverage]

        sample[5] = coverage


    # Plot for each sample
    num_plots =1
    fig = plt.figure(figsize=(15, 4 * num_plots), dpi = 300)
    for i, sample in enumerate(samples):
        name, canonical, circle, _, x_fill, y_fill, colors = sample

        # Center the plot on the canvas
        ax = plt.subplot(num_plots, 1, i+1)
        ybottom = height = 0.5
        ytop = ybottom + height
        if args.alu:
            alus = alu_file_parse(args.alu, chromosome, transcript_start, transcript_stop)
            plot_coords_right, plot_coords_left = [], []
            for alu_start, alu_stop, alu_strand in alus:
                mid_point = (alu_start + alu_stop) / 2
                if alu_strand == '+':
                    plot_coords_right.append(mid_point)
                else:
                    plot_coords_left.append(mid_point)

            ax.plot(plot_coords_left, [-0.5]*len(plot_coords_left), 'k<',ms=5)
            ax.plot(plot_coords_right, [-.5] *len(plot_coords_right), 'k>', ms=5)

        sequence = ''
        if args.rnabp:
            try:
                fasta_path = args.fa
                sequence = read_fasta(fasta_path, chromosome, transcript_start, transcript_stop, strand)

            except Exception:
                # Best effort: without a readable sequence the binding sites
                # are simply not plotted.
                print("RNA binding protein(s), %s, supplied. Genome FASTA file is also required (-fa)" % ' '.join(args.rnabp))


            y_bp_adjust = 0.49
            if args.rnabpc:
                rnabpc = cycle(args.rnabpc)
                cmap = [next(rnabpc) for i in range(len(args.rnabp))]
                color_ind = 0
                print(args.rnabpc)
                print(cmap)
            else:
                cmap = plt.get_cmap('Set1').colors
                color_ind = 4

            if sequence:
                for bp in args.rnabp:
                    bp_position_list = bp_positions(bp, sequence, transcript_start)

                    if args.intron_scale:
                        bp_position_list = [transform(exon_coordinates, scaled_coords, i) for i in bp_position_list]
                    plot_bp(ax, bp_position_list, y_bp_adjust * (ybottom + ytop), cmap[color_ind], transcript_len)
                    y_bp_adjust -= .01
                    print(y_bp_adjust)
                    # Wrap around the actual palette length (9 for 'Set1').
                    color_ind = (color_ind + 1) % len(cmap)

        # Calculated again here in case user requests intron scaling.
        transcript_start = min([int(i[0]) for i in exon_coordinates])
        transcript_stop = max([int(i[1]) for i in exon_coordinates])
        gene_length = transcript_stop - transcript_start

        # Add room on left and right of plot.
        x_adjustment = 0.05 * gene_length

        # Add room on top and bottom of plot. Include enough space here, otherwise curves will exceed the ax limits.
        y_adjustment = 4 * (ytop * height)

        xmin = transcript_start - x_adjustment
        xmax = transcript_stop + x_adjustment
        ymin = ybottom - y_adjustment
        ymax = ytop + y_adjustment
        ax.set_xlim([xmin, xmax])
        ax.set_ylim([ymin, ymax])
        # No plt.show() here: the Agg backend selected at import time is
        # non-interactive and the figure is written with savefig() below.

        # Turn off axis labeling.
        ax.axes.get_yaxis().set_visible(False)
        ax.axes.get_xaxis().set_visible(False)

        # Plot.
        print("start plot")
        draw_exons(ax=ax, exon_coords=exon_coordinates, y=ybottom, height=height, colors=colors,gene_length=gene_length)
        print("draw_exons has finished")
        if args.TS or args.anno:
            # Backsplice curves are omitted when a -TS/-anno filter is active.
            print("only show Tumor-specific junctions or unannotated junctions")
        else:
            plot_circles(ax=ax, coordinates=circle, gene_size=gene_length,y=ybottom + .2 *height, cds_range=cds_range )
            print("circle has finished")
        plot_SJ_curves(ax=ax, coordinates=canonical, y=ytop, cds_range=cds_range)

        print("SJ_curve has finished")

        # Replace special characters with spaces and plot sample name above each subplot.
        name = re.sub(r'[-_|]',' ', name)
        ax.set_title(name)
        ax.text(0.2, 0.2, r'linear', fontsize=15)

    title = label
    outformat = args.format.lower()
    plt.tight_layout()
    if args.TS or args.anno:
        print("only show Tumor-specific junctions or unannotated junctions")
    else:
        custom_lines = [Line2D([0], [0], color="red", lw=4),
                        Line2D([0], [0], color="blue", lw=4)]
        plt.legend(custom_lines, ['Unannotated', 'Tumor-specific'],loc='upper right',handlelength=0.9)
        text_start=transcript_start - x_adjustment*0.9
        circ_text_height=ymin+height*2.5
        linear_text_height=ymax-height*2.5
        plt.text(text_start,circ_text_height,"back splicing")
        plt.text(text_start,linear_text_height,"linear splicing")
    plt.savefig("{title}.{extension}".format(title=title, extension=outformat))
    # Template for the HTML wrapper page; currently only referenced by the
    # disabled export code that follows this function.
    html_str = '''
    <html>
    <body>
    <img src="{title}.{extension}" alt="Loading {title}.{extension}..">
    </body>
    </html>
    '''
# NOTE: the string literal below is disabled code (HTML export and opening the
# result in a browser) that used to be the tail of main(). As a bare
# module-level string it is evaluated and discarded; it has no effect.
'''
    with open("%s.html" % title, "w") as html:
        html.write(html_str.format(title=title, extension=outformat))

    if not args.repress_open:
        webbrowser.open('file://' + os.path.realpath("%s.html" % title))
'''
# Run the plot only when the file is executed as a script.
if __name__ == '__main__':
    main()



